import os

import pandas as pd
import torch

from Image_Mediator_Training.imageMediator_graph import set_imageMediator
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import asKey, getdoKey, get_dataset
from ModularUtils.FunctionsDistribution import get_joint_distributions_from_samples, calculate_TVD, calculate_KL
from ModularUtils.ControllerConstants import map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels, get_fake_distribution
from ModularUtils.DigitImageGeneration.mnist_image_generation import plot_trained_digits
from ModularUtils.FrontBackDoorCalculation import estiamte_ate_backdoor_direct
from ModularUtils.FunctionsTraining import get_training_variables, save_results
from ModularUtils.Functions_Plot_Results import plot_saved_results


def imageMediatorEvaluation(Exp, cur_hnodes, label_generators, dataset_dict, tvd_diff, kl_diff,
                            *, show_image=True, plot_every=1):
    """Score the label generators against every query in ``Exp.interv_queries``.

    For each query and each intervention ``key`` in ``query["intervs"]``, the
    generated distribution over ``query["obs"]`` is compared with a reference
    distribution. The TVD / KL values, rounded to 4 decimals, are appended to
    ``tvd_diff[query_str]`` / ``kl_diff[query_str]``, where ``query_str`` is
    ``getdoKey(query["obs"], dict(asKey(key)))``. Values are recorded only for
    query strings that already exist as keys in ``tvd_diff``.

    * ``key == {}`` (observational): the reference is the empirical joint
      distribution of the real samples in ``dataset_dict[asKey(key)]["obs"]``.
      Queries with no observed variables are skipped entirely. Queries with no
      entry in ``dataset_dict`` get no distance values and a message is printed.
    * otherwise (interventional): the reference is the result of
      ``estiamte_ate_backdoor_direct`` on the ``medU0``/``medD``/``medC``
      datasets, looked up by the intervened value.

    For observational queries, sample images are also plotted to
    ``<cwd>/PLOTS`` when ``show_image`` is true and
    ``(Exp.curr_epoochs + 1) % plot_every == 0``. Afterwards the results are
    saved via ``save_results``, and the last 10 TVD values per query are printed.

    The generators are put in ``eval()`` mode for the evaluation. They are
    always returned to ``train()`` mode, even if an exception is raised.

    Args:
        Exp: experiment object (provides ``interv_queries``, ``label_names``,
            ``image_labels``, ``SAVED_PATH``, ``curr_epoochs`` and the loss
            histories passed to ``save_results``).
        cur_hnodes: not used by this function; kept for call compatibility.
        label_generators: mapping of name -> object with ``eval()``/``train()``
            (presumably ``torch.nn.Module`` generators).
        dataset_dict: mapping of ``asKey(intervention)`` -> dict whose ``"obs"``
            entry is a 2-D tensor of real samples (columns indexed by label).
        tvd_diff: mapping of query string -> list of TVD values; appended to in place.
        kl_diff: mapping of query string -> list of KL values; appended to in place.
            Expected to have the same keys as ``tvd_diff``.
        show_image: set False to skip plotting generated images.
        plot_every: plot only when ``(Exp.curr_epoochs + 1) % plot_every == 0``.
            The default of 1 plots on every call.

    Returns:
        The same ``(tvd_diff, kl_diff)`` dicts, updated in place.
    """
    for generator in label_generators.values():
        generator.eval()

    try:
        with torch.no_grad():
            for query in Exp.interv_queries:
                compare_Var = query["obs"]
                for key in query["intervs"]:
                    intv_key = asKey(key)
                    query_str = getdoKey(compare_Var, dict(intv_key))
                    is_observational = key == {}

                    if is_observational:
                        if len(compare_Var) == 0:
                            # Nothing to compare (and nothing to plot) without observed variables.
                            continue
                        distances = _observational_distances(
                            Exp, label_generators, dataset_dict, key, intv_key, compare_Var)
                        if distances is None:
                            print("### skipping", query_str, ": no real samples in dataset_dict")
                    else:
                        distances = _interventional_distances(Exp, label_generators, key, compare_Var)

                    # TODO: carried over from an earlier "todo: fix it" note; its intent was not recorded.
                    if distances is not None and query_str in tvd_diff:
                        obs_tvd, obs_kl = distances
                        tvd_diff[query_str].append(round(obs_tvd, 4))
                        kl_diff[query_str].append(round(obs_kl, 4))

                    if is_observational and show_image and (Exp.curr_epoochs + 1) % plot_every == 0:
                        _plot_generated_images(Exp, label_generators, compare_Var)

            # These collections are never populated in this function; they are
            # handed to save_results empty.
            all_generated_labels = {}
            all_real_labels = {}
            save_results(Exp, Exp.SAVED_PATH, all_generated_labels, all_real_labels,
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Restore training mode even when evaluation fails part-way.
        for generator in label_generators.values():
            generator.train()

    _print_recent_losses(tvd_diff)
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff


def _observational_distances(Exp, label_generators, dataset_dict, key, intv_key, compare_Var):
    """Return ``(TVD, KL)`` between the generated and the empirical P(compare_Var).

    Returns None when ``dataset_dict`` holds no real samples for ``intv_key``.
    """
    if intv_key not in dataset_dict:
        return None

    _, _, _, graph_label_vars = get_training_variables(Exp, Exp.label_names, 0, key)
    obs_indices = [graph_label_vars.index(lb) for lb in compare_Var]
    # Select the observed columns and cast them to integer labels.
    # ``.type(torch.LongTensor)`` yields a CPU tensor, so there is no need to
    # move it to a device before converting it to numpy.
    real_labels = (dataset_dict[intv_key]["obs"][:, obs_indices]
                   .type(torch.LongTensor)
                   .view(-1, len(obs_indices)))

    fake_dist_dict = get_fake_distribution(Exp, label_generators, intv_key, compare_Var)
    dataset_dist_dict = get_joint_distributions_from_samples(
        Exp, compare_Var, real_labels.cpu().numpy().astype(int), "feature")
    return (calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False),
            calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False))


def _interventional_distances(Exp, label_generators, key, compare_Var):
    """Return ``(TVD, KL)`` between the generated interventional distribution and a data-based estimate."""
    fake_dist_dict = get_fake_distribution(Exp, label_generators, key, compare_Var)
    print('fake intv dist_dict', fake_dist_dict)

    # NOTE(review): the datasets are reloaded for every interventional key.
    # Hoisting is safe only if get_dataset and the estimator have no side effects.
    D = get_dataset(Exp, 'medD', 0)
    U0 = get_dataset(Exp, 'medU0', 0)
    genC = get_dataset(Exp, 'medC', 0)
    samples = pd.DataFrame(torch.cat([U0, D, genC], 1).cpu().numpy())
    samples = samples.rename(columns={0: 'medU0', 1: 'medD', 2: 'medC'})

    # The estimator's result is indexed by the intervened value (the first value
    # in ``key``). Argument roles (treatment 'medD', outcome 'medC', adjustment
    # set ['medU0']) are presumed from the call's order — confirm.
    intervened_value = next(iter(key.values()))
    estimate = estiamte_ate_backdoor_direct(Exp, samples, 'medD', 'medC', ['medU0'])[intervened_value]
    # Wrap each outcome in a 1-tuple to match the keys used for the fake distribution.
    dataset_dist_dict = {(outcome,): prob for outcome, prob in estimate.items()}
    return (calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False),
            calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False))


def _plot_generated_images(Exp, label_generators, compare_Var, minibatch=2):
    """Generate ``minibatch`` samples with an image and plot each one, titled by its labels, under <cwd>/PLOTS."""
    image_label = Exp.image_labels[0]
    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, {},
                                                 compare_Var + [image_label], minibatch, hard=True)
    # Separate the image from the label outputs before discretizing the labels.
    generated_image = generated_labels_dict.pop(image_label)
    generated_labels_ig = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)

    plot_dir = f'{os.getcwd()}/PLOTS'
    for grow, genimg in zip(generated_labels_ig, generated_image):
        print("gen", grow)
        # (C, H, W) -> (H, W, C) for plotting.
        genimg = genimg.permute(1, 2, 0).detach().cpu().numpy()
        plot_trained_digits(1, 1, [genimg], f'Real {grow}', plot_dir)


def _print_recent_losses(tvd_diff, window=10):
    """Print the last ``window`` TVD values recorded for each query."""
    for dist, values in tvd_diff.items():
        print("###", dist, " loss%:", [round(val, 4) for val in values[-window:]])




# Module-level driver: opens the "Exp1" image-mediator experiment with
# new_experiment=False, then plots saved results against two benchmark runs.
# NOTE(review): these statements execute on import, not only when run as a script.
Exp = Experiment(
    "Exp1",
    set_imageMediator,
    new_experiment=False,
    features=["feature"],
    Data_intervs=[{}],
)

# Directory containing the saved runs of this experiment.
root = "/path_to_project/SAVED_EXPERIMENTS/imageMediator/Exp1"

# Run of interest, and the benchmark runs ("ncm" and "rep") to compare against.
exp = 'May_01_2023-11_54'
bnc_exp = ['May_01_2023-04_42', 'May_10_2023-08_54_fullrep']

# Plot labels, one row per distribution: the plain label followed by its
# "ncm"- and "rep"-prefixed variants.
pre_labels = (
    ['$P(D,A)$', '$ncmP(D,A)$', 'rep$P(D,A)$']
    + ['P(A|do(D=0))', 'ncmP(A|do(D=0))', 'repP(A|do(D=0))']
    + ['P(A|do(D=1))', 'ncmP(A|do(D=1))', 'repP(A|do(D=1))']
)

last_exp = "/".join((root, exp))
benchmarks = [(tag, f'{root}/{run}') for tag, run in zip(('ncm', 'rep'), bnc_exp)]

# NOTE(review): `last_exp` is built above but None is passed as the second
# argument; confirm which is intended. (Original remark: "only whatifgan".)
plot_saved_results(
    Exp,
    None,
    epochs=300,
    delta=10,
    pre_labels=pre_labels,
    benchmarks=benchmarks,
)




